# import python package
from os.path import join

def build_interval(chr_len_file, threads = 50, debug=False):
    """Group contigs into consecutive intervals of roughly equal total length.

    Contigs are taken in file order and appended to the current interval
    until that interval's accumulated length reaches ``total_len // threads``;
    the next contig then opens a new interval. A contig is never split, so an
    interval may exceed the target size and the number of intervals may
    differ from ``threads``.

    Parameters
    ----------
    chr_len_file : str
        Path to a whitespace-separated text file whose first column is the
        contig name and whose second column is the contig length. Blank lines
        are skipped and any additional columns are ignored.
    threads : int
        Number of chunks used to derive the target interval size.
    debug : bool
        When true, print the total length and contig names of every interval.

    Returns
    -------
    collections.defaultdict
        Maps the 0-based interval index (int) to the list of contig names
        assigned to that interval.
    """

    from collections import defaultdict

    # Read contig names and lengths; the context manager closes the file.
    contigs = []
    with open(chr_len_file, "r") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            contigs.append((fields[0], int(fields[1])))

    total_len = sum(length for _, length in contigs)
    interval_len = total_len // threads

    interval_name_dict = defaultdict(list)
    # Running total per interval, so the length is not re-summed per contig.
    interval_total = defaultdict(int)
    j = 0
    for name, length in contigs:
        # Open a new interval only when the current one already holds at
        # least one contig and has reached the target size.
        if interval_name_dict[j] and interval_total[j] >= interval_len:
            j += 1
        interval_name_dict[j].append(name)
        interval_total[j] += length

    if debug:
        print("Debug Info")
        # Iterate over every interval that was filled, including the last one.
        for i, names in interval_name_dict.items():
            print("{}: {} \t {} ".format(i, interval_total[i], " ".join(names)))

    return interval_name_dict


# configuration: every value below is read from config.yaml
configfile: "config.yaml"
FASTA    = config['FASTA']
INDEX    = config['INDEX']
PLOIDY   = config['PLOIDY']
CHRFILE  = config['CHRFILE']
VCFJOBS  = config['VCFJOBS']
snpEff   = config["snpEff"]
SnpEffDB = config["SnpEffDB"]

# One sample name per line; blank lines are ignored so that a trailing newline
# does not produce an empty sample name. The file is closed after reading.
with open(config['samples']) as sample_fh:
    SAMPLES = [line.strip() for line in sample_fh if line.strip()]
ADAPTER  = ""

# Map interval index -> list of contig names; each interval becomes one
# variant-calling job.
interval_name_dict = build_interval(CHRFILE, VCFJOBS)

# How to deal with duplication
DEDUPLICATE = config['DEDUPLICATE']
# NOTE(review): `readfilter` is not referenced by any rule visible in this
# part of the file; it is kept in case it is used elsewhere.
if DEDUPLICATE:
    readfilter = "--read-filter NotDuplicateReadFilter"
else:
    readfilter = "--disable-read-filter NotDuplicateReadFilter"

# make working directories
RAW_DIR   = join("analysis", "00-raw-data")
CLEAN_DIR = join("analysis", "01-clean-data")
ALIGN_DIR = join("analysis", "02-read-alignment")
VCF_DIR   = join("analysis", "03-vcf-calling")
FNL_DIR   = join("analysis", "04-final-vcf")
QC_DIR    = join("report", "fastp")
TMPDIR    = join("analysis", "temp")
LOGDIR    = join("analysis", "log")

# RAW_DIR is not created here; presumably the raw reads are supplied by the user.
shell("mkdir -p {CLEAN_DIR} {ALIGN_DIR} {VCF_DIR} {FNL_DIR}")
shell("mkdir -p {QC_DIR} {TMPDIR} {LOGDIR}")

# The BAM file name suffix selects which alignment rule produces the file.
if DEDUPLICATE:
    ALL_BAM  = [join(ALIGN_DIR, "{}_markdup_sorted.bam".format(sample)) for sample in SAMPLES]
else:
    ALL_BAM  = [join(ALIGN_DIR, "{}_sorted.bam".format(sample)) for sample in SAMPLES]


# One BCF per interval, named by the interval index, in interval order.
ALL_BCF  = [join(VCF_DIR, "{}.bcf".format(interval_id)) for interval_id in interval_name_dict]
MERGED_BCF = join(FNL_DIR, "merged.bcf")
#VCFS     = join(VCF_DIR, "freebayes.var.vcf") # FreeBayes

# The final target is the annotated VCF when annotation is enabled,
# otherwise the merged BCF.
if config['Annotation']:
    FINAL_VCF = join(FNL_DIR, "anno.vcf.gz")
else:
    FINAL_VCF = MERGED_BCF



# Default target rule: requests FINAL_VCF, which is either the annotated VCF
# or the merged BCF depending on config['Annotation'].
rule all:
    input:
        FINAL_VCF

# Trim and filter the raw paired-end reads of one sample with fastp.
# Reads {sample}_1/_2.fq.gz from RAW_DIR, writes the cleaned pair to CLEAN_DIR
# and the fastp JSON/HTML reports to QC_DIR/{sample}.{json,html}.
rule data_clean:
    input:
        r1 = join(RAW_DIR, "{sample}_1.fq.gz"),
        r2 = join(RAW_DIR, "{sample}_2.fq.gz")
    params:
        adapter = ADAPTER,
        # expands to "-a <ADAPTER>" only when an adapter sequence is configured
        adapter_param = lambda wildcards, resources: '-a ' + ADAPTER if ADAPTER else '',
        length  = 50,
        quality = 20,
        prefix  = lambda wildcards: join(QC_DIR, wildcards.sample)
    output:
        r1 = join(CLEAN_DIR, "{sample}_1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_2.fq.gz")
    log: join(LOGDIR, "{sample}_fastp.log")
    threads: 4
    conda: "env.yaml"
    message: "----- processing {input.r1} and {input.r2} with fastp ------"
    # stderr is redirected to the declared log file, which was previously
    # declared but never written.
    shell: """
        fastp {params.adapter_param} \
        -w {threads} -j {params.prefix}.json -h {params.prefix}.html -l {params.length} -q {params.quality} \
        -i {input.r1} -I {input.r2} -o {output.r1} -O {output.r2} 2> {log}
    """


# Align the cleaned read pair of one sample and produce a coordinate-sorted
# BAM named {sample}_markdup_sorted.bam.
# Pipeline: bwa mem (stderr -> log1) | samblaster (stderr -> log2; split and
# discordant reads go to the side files named in params) | sambamba view
# (SAM -> uncompressed BAM) | sambamba sort (writes the final output).
rule bwa_mem_deduplicate:
    input:
        r1 = join(CLEAN_DIR, "{sample}_1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_2.fq.gz")
    params:
        index  = INDEX,
        # read-group header line; raw string so "\t" reaches bwa literally
        rg     = r"@RG\tID:{sample}\tSM:{sample}\tPL:ILLUMINA",
        tmpdir = TMPDIR,
        # NOTE(review): logdir is not referenced in the shell command below
        logdir = LOGDIR,
        # side files written by samblaster; they are not declared as outputs
        split  = join(ALIGN_DIR,"{sample}_split.sam"),
        disc   = join(ALIGN_DIR,"{sample}_disc.sam")
    output:
        join(ALIGN_DIR,"{sample}_markdup_sorted.bam"),
    log:
        log1=join(LOGDIR, "{sample}_align_warning.log"),
        log2=join(LOGDIR, "{sample}_samblaster.log")
    conda: "env.yaml"
    threads: 20
    message: "align {input.r1} and {input.r2} to {params.index} with {threads} threads"
    shell:
        """
        bwa mem -v 2 -t {threads} -R '{params.rg}' {params.index} {input.r1} {input.r2} 2> {log.log1} |\
        samblaster --acceptDupMarks --addMateTags --splitterFile {params.split} --discordantFile {params.disc} 2> {log.log2} |\
        sambamba view -S -f bam -l 0 /dev/stdin |\
        sambamba sort -t {threads} -m 2G --tmpdir {params.tmpdir} -o {output} /dev/stdin
        """

# Align the cleaned read pair of one sample without the samblaster step and
# produce a coordinate-sorted BAM named {sample}_sorted.bam.
# Pipeline: bwa mem (stderr -> log) | sambamba view (SAM -> uncompressed BAM)
# | sambamba sort (writes the final output).
rule bwa_mem_no_deduplicate:
    input:
        r1 = join(CLEAN_DIR, "{sample}_1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_2.fq.gz")
    params:
        index  = INDEX,
        # read-group header line; raw string so "\t" reaches bwa literally
        rg     = r"@RG\tID:{sample}\tSM:{sample}\tPL:ILLUMINA",
        tmpdir = TMPDIR,
        # NOTE(review): logdir is not referenced in the shell command below
        logdir = LOGDIR
    output:
        join(ALIGN_DIR,"{sample}_sorted.bam"),
    log:
        join(LOGDIR, "{sample}_align_warning.log")
    conda: "env.yaml"
    threads: 20
    message: "align {input.r1} and {input.r2} to {params.index} with {threads} threads"
    shell:
        """
        bwa mem -v 2 -t {threads} -R '{params.rg}' {params.index} {input.r1} {input.r2} 2> {log} |\
        sambamba view -S -f bam -l 0 /dev/stdin |\
        sambamba sort -t {threads} -m 2G --tmpdir {params.tmpdir} -o {output} /dev/stdin
        """

# Call variants for one interval across all BAM files with
# bcftools mpileup | bcftools call, writing VCF_DIR/{chr}.bcf.
# The wildcard {chr} is the integer interval index, not a contig name.
# NOTE(review): `threads` is declared but not forwarded to either bcftools
# command in the shell block.
rule singleVCF:
    input:
        ALL_BAM
    params:
        ref       = FASTA,
        # comma-separated contig names belonging to this interval index
        interval = lambda wildcards: ",".join(interval_name_dict[int(wildcards.chr)])
    output:
       join(VCF_DIR, "{chr}.bcf")
    threads: 6 
    conda: "env.yaml"
    shell:"""
	bcftools mpileup -f {params.ref} \
        --redo-BAQ --min-BQ 30 \
        --per-sample-mF \
        --annotate FORMAT/AD,FORMAT/DP \
        --regions {params.interval} \
        -Ou {input}  | \
        bcftools call -mv -Ob -o {output}
	"""

# Concatenate the per-interval BCF files (in interval order) into one merged
# BCF using `bcftools concat --naive`.
rule MergeVCFS:
    input:
        ALL_BCF
    output:
        MERGED_BCF
    threads: 1
    # same environment as the other bcftools rules, so the tool is available
    # when the workflow runs with --use-conda
    conda: "env.yaml"
    shell:"""
        bcftools concat --naive -o {output} {input}
    """

# Convert the merged BCF into an uncompressed text VCF with `bcftools view`.
# The directives are indented consistently; the former 3-space `input:`
# followed by a 4-space `output:` is an inconsistent dedent.
rule BCFtoVCF:
    input:
        MERGED_BCF
    output:
        join(FNL_DIR, "merged.vcf")
    conda: "env.yaml"
    shell:"""
        bcftools view {input} > {output}
    """

# Annotate the merged VCF with SnpEff, compress the result with bcftools and
# index it. The SnpEff summary report is written to report/snpeff.
# NOTE(review): no conda environment is declared here; java and bcftools are
# taken from whatever environment runs the job.
rule snpeff:
    input:
        join(FNL_DIR, "merged.vcf")
    params:
        # value of config["snpEff"], passed to `java -jar`
        jar    = snpEff,
        db     = SnpEffDB,
        outdir = "report/snpeff"
    output:
        vcf   = join(FNL_DIR, "anno.vcf.gz"),
        index = join(FNL_DIR, "anno.vcf.gz.csi")
    message: "annotate with SnpEff"
    shell:"""
        mkdir -p {params.outdir}
        java -Xmx4G -jar {params.jar} ann -o gatk -s {params.outdir}/snpeff_summary.html {params.db} {input} | \
        bcftools view  -Oz -o {output.vcf}
        bcftools index {output.vcf}
        """